#!/usr/bin/env python3

import ast
import copy
import ipaddress
import os
import sys
import subprocess
import syslog
import signal

import jinja2
from sonic_py_common import device_info
from swsscommon.swsscommon import ConfigDBConnector, DBConnector, Table

# ---------------------------------------------------------------------------
# Files generated or edited by this daemon, grouped with their templates
# ---------------------------------------------------------------------------

# PAM authentication stack used by sshd/login, and the template it is rendered from
PAM_AUTH_CONF = "/etc/pam.d/common-auth-sonic"
PAM_AUTH_CONF_TEMPLATE = "/usr/share/sonic/templates/common-auth-sonic.j2"

# NSS configuration files for TACACS+ and RADIUS, and their templates
NSS_TACPLUS_CONF = "/etc/tacplus_nss.conf"
NSS_TACPLUS_CONF_TEMPLATE = "/usr/share/sonic/templates/tacplus_nss.conf.j2"
NSS_RADIUS_CONF = "/etc/radius_nss.conf"
NSS_RADIUS_CONF_TEMPLATE = "/usr/share/sonic/templates/radius_nss.conf.j2"

# Template for the per-server pam_radius_auth configuration files
PAM_RADIUS_AUTH_CONF_TEMPLATE = "/usr/share/sonic/templates/pam_radius_auth.conf.j2"

# System files that are patched in place rather than rendered from a template
NSS_CONF = "/etc/nsswitch.conf"
ETC_PAMD_SSHD = "/etc/pam.d/sshd"
ETC_PAMD_LOGIN = "/etc/pam.d/login"

# pam_limits / limits.conf files and their templates
PAM_LIMITS_CONF = "/etc/pam.d/pam-limits-conf"
PAM_LIMITS_CONF_TEMPLATE = "/usr/share/sonic/templates/pam_limits.j2"
LIMITS_CONF = "/etc/security/limits.conf"
LIMITS_CONF_TEMPLATE = "/usr/share/sonic/templates/limits.conf.j2"

# ---------------------------------------------------------------------------
# TACACS+ global defaults
# ---------------------------------------------------------------------------
TACPLUS_SERVER_PASSKEY_DEFAULT = ""
TACPLUS_SERVER_TIMEOUT_DEFAULT = "5"
TACPLUS_SERVER_AUTH_TYPE_DEFAULT = "pap"

# ---------------------------------------------------------------------------
# RADIUS global defaults
# ---------------------------------------------------------------------------
RADIUS_SERVER_AUTH_PORT_DEFAULT = "1812"
RADIUS_SERVER_PASSKEY_DEFAULT = ""
RADIUS_SERVER_RETRANSMIT_DEFAULT = "3"
RADIUS_SERVER_TIMEOUT_DEFAULT = "5"
RADIUS_SERVER_AUTH_TYPE_DEFAULT = "pap"
# Directory receiving one pam_radius_auth configuration file per RADIUS server
RADIUS_PAM_AUTH_CONF_DIR = "/etc/pam_radius_auth.d/"

# ---------------------------------------------------------------------------
# Miscellaneous
# ---------------------------------------------------------------------------
CFG_DB = "CONFIG_DB"
STATE_DB = "STATE_DB"
HOSTCFGD_MAX_PRI = 10  # Used to enforce ordering b/w daemons under Hostcfgd
DEFAULT_SELECT_TIMEOUT = 1000


def safe_eval(val, default_value=False):
    """ Safely evaluate a Python literal expression, without raising an exception.

    Args:
        val: Text (or AST node) expected to hold a Python literal such as 'True' or '[1, 2]'.
        default_value: Value returned when ``val`` cannot be evaluated as a literal.

    Returns:
        The evaluated literal, or ``default_value`` if evaluation fails.
    """
    try:
        return ast.literal_eval(val)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # literal_eval raises SyntaxError for unparsable text (not only ValueError
        # for non-literal input), so every documented failure maps to the default.
        return default_value


def signal_handler(sig, frame):
    """ Process-wide signal handler.

    SIGINT and SIGTERM are logged and terminate the process with exit status
    128 + <signal number>. SIGHUP and any other signal are logged and ignored.
    """
    exit_messages = {
        signal.SIGINT: "HostCfgd: signal 'SIGINT' is caught and exiting...",
        signal.SIGTERM: "HostCfgd: signal 'SIGTERM' is caught and exiting...",
    }

    if sig in exit_messages:
        syslog.syslog(syslog.LOG_INFO, exit_messages[sig])
        sys.exit(128 + sig)

    if sig == signal.SIGHUP:
        message = "HostCfgd: signal 'SIGHUP' is caught and ignoring.."
    else:
        message = "HostCfgd: invalid signal - ignoring.."
    syslog.syslog(syslog.LOG_INFO, message)


def run_cmd(cmd, log_err=True, raise_exception=False):
    """ Run a shell command, optionally logging and/or re-raising failures.

    Args:
        cmd (str): Command line executed through the shell.
        log_err (bool): Log failures to syslog at LOG_ERR.
        raise_exception (bool): Re-raise the failure to the caller after logging.

    Raises:
        subprocess.CalledProcessError: Command exited non-zero and raise_exception is set.
        Exception: Any other error from launching the command, when raise_exception is set.
    """
    try:
        subprocess.check_call(cmd, shell=True)
    except subprocess.CalledProcessError as err:
        if log_err:
            syslog.syslog(syslog.LOG_ERR, "{} - failed: return code - {}, output:\n{}"
                          .format(err.cmd, err.returncode, err.output))
        if raise_exception:
            raise
    except Exception as err:
        # Errors raised before the command runs (e.g. TypeError/OSError from Popen)
        # carry no cmd/returncode attributes; reading them used to raise
        # AttributeError and mask the real error.
        if log_err:
            syslog.syslog(syslog.LOG_ERR, "{} - failed: {}".format(cmd, err))
        if raise_exception:
            raise


def is_true(val):
    """ Return True only when val is exactly the string 'True' or 'true'. """
    return val in ('True', 'true')


def is_vlan_sub_interface(ifname):
    """ Return True when ifname contains exactly one '.', e.g. 'Ethernet0.10'. """
    return ifname.count(".") == 1


def sub(l, start, end):
    """ Return the slice l[start:end]; registered as the 'sub' Jinja2 filter. """
    window = slice(start, end)
    return l[window]


def obfuscate(data):
    """ Mask a secret for logging: keep its first character and append '*****'.

    Falsy input (empty string, None) is returned unchanged.
    """
    return data[0] + '*****' if data else data

def get_pid(procname):
    """ Return the PID of the first process whose command line contains procname.

    Args:
        procname (str): Substring searched for in each /proc/<pid>/cmdline.

    Returns:
        str: The matching PID as a decimal string, or "" if no process matches.
    """
    for dirname in os.listdir('/proc'):
        # Only numeric entries are processes. Skipping the others also avoids
        # returning the '/proc/self' alias (this process), which callers cannot
        # convert with int(), as well as BSD's 'curproc'.
        if not dirname.isdigit():
            continue
        try:
            with open('/proc/{}/cmdline'.format(dirname), mode='r') as fd:
                content = fd.read()
        except (OSError, ValueError):
            # The process exited after listdir(), or its cmdline is unreadable/undecodable.
            continue
        if procname in content:
            return dirname
    return ""

class Feature(object):
    """ Represents a feature configuration from CONFIG_DB data. """

    def __init__(self, feature_name, feature_cfg, device_config=None):
        """ Build a Feature from one FEATURE entry of CONFIG_DB.

        Args:
            feature_name (str): Feature name string
            feature_cfg (dict): Feature CONFIG_DB configuration
            device_config (dict): DEVICE_METADATA section of CONFIG_DB (None is treated as {})
        """
        lookup = feature_cfg.get

        # The name must be set first: _get_target_state uses it in its error message.
        self.name = feature_name
        self.state = self._get_target_state(lookup('state'), device_config or {})
        self.auto_restart = lookup('auto_restart', 'disabled')

        # Flags are stored as literal text ('True'/'False') and evaluated via safe_eval.
        self.has_timer = safe_eval(lookup('has_timer', 'False'))
        self.has_global_scope = safe_eval(lookup('has_global_scope', 'True'))
        self.has_per_asic_scope = safe_eval(lookup('has_per_asic_scope', 'False'))

    def _get_target_state(self, state_configuration, device_config):
        """ Render the configured state as a Jinja2 template and validate the result.

        Args:
            state_configuration (str): State configuration from CONFIG_DB, or None
            device_config (dict): DEVICE_METADATA section of CONFIG_DB, used as template context
        Returns:
            (str): One of 'enabled', 'disabled', 'always_enabled', 'always_disabled',
                   or None when no state is configured
        Raises:
            ValueError: The rendered text is not one of the accepted states
        """
        if state_configuration is None:
            return None

        rendered = jinja2.Template(state_configuration).render(device_config)
        if rendered in ('enabled', 'disabled', 'always_enabled', 'always_disabled'):
            return rendered
        raise ValueError('Invalid state rendered for feature {}: {}'.format(self.name, rendered))

    def compare_state(self, feature_name, feature_cfg):
        """ Return True if this feature has the given name and its state equals feature_cfg['state']. """
        same_feature = self.name == feature_name and isinstance(feature_cfg, dict)
        return same_feature and self.state == feature_cfg.get('state', '')


class FeatureHandler(object):
    """ Handles FEATURE table updates.

    Translates FEATURE entries into systemd actions (unmask/enable/start or
    stop/disable/mask), writes per-service auto_restart drop-ins, and records the
    outcome in the feature state table.
    """

    SYSTEMD_SYSTEM_DIR = '/etc/systemd/system/'
    SYSTEMD_SERVICE_CONF_DIR = os.path.join(SYSTEMD_SYSTEM_DIR, '{}.service.d/')

    # Feature state constants
    FEATURE_STATE_ENABLED = "enabled"
    FEATURE_STATE_DISABLED = "disabled"
    FEATURE_STATE_FAILED = "failed"

    def __init__(self, config_db, feature_state_table, device_config):
        """
        Args:
            config_db: CONFIG_DB connector; resync_feature_state() writes FEATURE entries through it.
            feature_state_table: Table receiving per-feature runtime state (written with set(),
                removed with _del()).
            device_config (dict): DEVICE_METADATA section, used to render feature state templates.
        """
        self._config_db = config_db
        self._feature_state_table = feature_state_table
        self._device_config = device_config
        # feature name -> Feature holding the last applied state / auto_restart values
        self._cached_config = {}
        self.is_multi_npu = device_info.is_multi_npu()

    def handle(self, feature_name, op, feature_cfg):
        """ Apply a single FEATURE table notification.

        Args:
            feature_name (str): Key of the FEATURE entry.
            op: Operation reported by the subscriber (not used here).
            feature_cfg (dict): New entry contents; empty when the entry was deleted.
        """
        if not feature_cfg:
            syslog.syslog(syslog.LOG_INFO, "Deregistering feature {}".format(feature_name))
            # The feature may never have been cached; do not raise KeyError for it.
            self._cached_config.pop(feature_name, None)
            self._feature_state_table._del(feature_name)
            return

        feature = Feature(feature_name, feature_cfg, self._device_config)
        if feature_name not in self._cached_config:
            # An unseen feature starts with state None, from which any target state is allowed.
            self._cached_config[feature_name] = Feature(feature_name, {})

        # Change auto-restart configuration first.
        # If service reached failed state before this configuration applies (e.g. on boot)
        # the next called self.update_feature_state will start it again. If it will fail
        # again the auto restart will kick-in. Another order may leave it in failed state
        # and not auto restart.
        self.update_feature_auto_restart(feature, feature_name)

        # Enable/disable the container service if the feature state was changed from its previous state.
        if self._cached_config[feature_name].state != feature.state:
            if self.update_feature_state(feature):
                self._cached_config[feature_name].state = feature.state
            else:
                # Transition rejected: write the last applied state back to CONFIG_DB.
                self.resync_feature_state(self._cached_config[feature_name])

    def sync_state_field(self, feature_table):
        """
        Summary:
        Updates the state field in the FEATURE|* tables as the state field
        might have to be rendered based on DEVICE_METADATA table
        """
        for feature_name in feature_table.keys():
            if not feature_name:
                syslog.syslog(syslog.LOG_WARNING, "Feature is None")
                continue

            feature = Feature(feature_name, feature_table[feature_name], self._device_config)

            self._cached_config.setdefault(feature_name, feature)
            self.update_feature_auto_restart(feature, feature_name)
            self.update_feature_state(feature)
            self.resync_feature_state(feature)

    def update_feature_state(self, feature):
        """ Enable or disable the feature's units according to the cached -> new state transition.

        Returns:
            bool: True if an enable or disable was attempted, False if the transition
                  is not allowed (nothing is changed in that case).
        """
        cached_feature = self._cached_config[feature.name]
        enable = False
        disable = False

        # Allowed transitions:
        #  None           -> always_enabled
        #                 -> always_disabled
        #                 -> enabled
        #                 -> disabled
        #  always_enabled -> always_disabled
        #  enabled        -> disabled
        #  disabled       -> enabled
        if cached_feature.state is None:
            enable = feature.state in ("always_enabled", "enabled")
            disable = feature.state in ("always_disabled", "disabled")
        elif cached_feature.state in ("always_enabled", "always_disabled"):
            disable = feature.state == "always_disabled"
            enable = feature.state == "always_enabled"
        elif cached_feature.state in ("enabled", "disabled"):
            enable = feature.state == "enabled"
            disable = feature.state == "disabled"
        else:
            syslog.syslog(syslog.LOG_INFO, "Feature {} service is {}".format(feature.name, cached_feature.state))
            return False

        if not enable and not disable:
            syslog.syslog(syslog.LOG_ERR, "Unexpected state value '{}' for feature {}"
                          .format(feature.state, feature.name))
            return False

        if enable:
            self.enable_feature(feature)
            syslog.syslog(syslog.LOG_INFO, "Feature {} is enabled and started".format(feature.name))

        if disable:
            self.disable_feature(feature)
            syslog.syslog(syslog.LOG_INFO, "Feature {} is stopped and disabled".format(feature.name))

        return True

    def update_feature_auto_restart(self, feature, feature_name):
        """ Write the systemd Restart= drop-in for every unit instance of the feature.

        The drop-in is rewritten only when it is missing or the cached auto_restart
        value differs from the new one; systemd is then reloaded.
        """
        dir_name = self.SYSTEMD_SERVICE_CONF_DIR.format(feature_name)
        auto_restart_conf = os.path.join(dir_name, 'auto_restart.conf')

        write_conf = False
        if not os.path.exists(auto_restart_conf):  # if the auto_restart_conf file is not found, set it
            write_conf = True

        if self._cached_config[feature_name].auto_restart != feature.auto_restart:
            write_conf = True

        if not write_conf:
            return

        self._cached_config[feature_name].auto_restart = feature.auto_restart  # Update Cache

        restart_config = "always" if feature.auto_restart == "enabled" else "no"
        service_conf = "[Service]\nRestart={}\n".format(restart_config)
        feature_names, feature_suffixes = self.get_feature_attribute(feature)

        for name in feature_names:
            dir_name = self.SYSTEMD_SERVICE_CONF_DIR.format(name)
            auto_restart_conf = os.path.join(dir_name, 'auto_restart.conf')
            os.makedirs(dir_name, exist_ok=True)
            with open(auto_restart_conf, 'w') as cfgfile:
                cfgfile.write(service_conf)

        try:
            run_cmd("sudo systemctl daemon-reload", raise_exception=True)
        except Exception:
            syslog.syslog(syslog.LOG_ERR, "Feature '{}' failed to configure auto_restart".format(feature.name))
            return

    def get_feature_attribute(self, feature):
        """ Return (unit base names, unit suffixes) for the feature.

        Base names are the host instance and/or one '<name>@<asic>' instance per ASIC
        on multi-ASIC systems; suffixes are ['service'] plus 'timer' when the feature
        has a timer (the last suffix is the unit that gets enabled/started).
        """
        # Create feature name suffix depending feature is running in host or namespace or in both
        feature_names = (
            ([feature.name] if feature.has_global_scope or not self.is_multi_npu else []) +
            ([(feature.name + '@' + str(asic_inst)) for asic_inst in range(device_info.get_num_npus())
                if feature.has_per_asic_scope and self.is_multi_npu])
        )

        if not feature_names:
            syslog.syslog(syslog.LOG_ERR, "Feature '{}' service not available"
                          .format(feature.name))

        feature_suffixes = ["service"] + (["timer"] if feature.has_timer else [])

        return feature_names, feature_suffixes

    def get_systemd_unit_state(self, unit):
        """ Return the UnitFileState of a systemd unit, or 'invalid' if it cannot be determined. """

        cmd = "sudo systemctl show {} --property UnitFileState".format(unit)
        proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = proc.communicate()
        if proc.returncode != 0:
            syslog.syslog(syslog.LOG_ERR, "Failed to get status of {}: rc={} stderr={}".format(unit, proc.returncode, stderr))
            return 'invalid'  # same as systemd's "invalid indicates that it could not be determined whether the unit file is enabled".

        props = {}
        for line in stdout.decode().strip().splitlines():
            # Split on the first '=' only so values containing '=' cannot break parsing.
            key, _, value = line.partition("=")
            props[key] = value
        return props.get("UnitFileState", 'invalid')

    def enable_feature(self, feature):
        """ Unmask, enable and start every instance of the feature; record the result. """
        feature_names, feature_suffixes = self.get_feature_attribute(feature)
        for feature_name in feature_names:
            # Check if it is already enabled, if yes skip the system call
            unit_file_state = self.get_systemd_unit_state("{}.{}".format(feature_name, feature_suffixes[-1]))
            if unit_file_state == "enabled":
                continue

            # Commands are per instance; a shared list would re-run earlier instances' commands.
            cmds = []
            for suffix in feature_suffixes:
                cmds.append("sudo systemctl unmask {}.{}".format(feature_name, suffix))

            # If feature has timer associated with it, start/enable corresponding systemd .timer unit
            # otherwise, start/enable corresponding systemd .service unit

            cmds.append("sudo systemctl enable {}.{}".format(feature_name, feature_suffixes[-1]))
            cmds.append("sudo systemctl start {}.{}".format(feature_name, feature_suffixes[-1]))

            for cmd in cmds:
                syslog.syslog(syslog.LOG_INFO, "Running cmd: '{}'".format(cmd))
                try:
                    run_cmd(cmd, raise_exception=True)
                except Exception:
                    syslog.syslog(syslog.LOG_ERR, "Feature '{}.{}' failed to be enabled and started"
                                  .format(feature_name, feature_suffixes[-1]))
                    self.set_feature_state(feature, self.FEATURE_STATE_FAILED)
                    return

        self.set_feature_state(feature, self.FEATURE_STATE_ENABLED)

    def disable_feature(self, feature):
        """ Stop, disable and mask every instance of the feature; record the result. """
        feature_names, feature_suffixes = self.get_feature_attribute(feature)
        for feature_name in feature_names:
            # Check if it is already disabled, if yes skip the system call
            unit_file_state = self.get_systemd_unit_state("{}.{}".format(feature_name, feature_suffixes[-1]))
            if unit_file_state in ("disabled", "masked"):
                continue

            # Commands are per instance; a shared list would re-run earlier instances' commands.
            cmds = []
            # Stop the timer (if any) before the service it triggers.
            for suffix in reversed(feature_suffixes):
                cmds.append("sudo systemctl stop {}.{}".format(feature_name, suffix))
            # Disable/mask the unit that enable_feature() enabled, exactly once.
            cmds.append("sudo systemctl disable {}.{}".format(feature_name, feature_suffixes[-1]))
            cmds.append("sudo systemctl mask {}.{}".format(feature_name, feature_suffixes[-1]))

            for cmd in cmds:
                syslog.syslog(syslog.LOG_INFO, "Running cmd: '{}'".format(cmd))
                try:
                    run_cmd(cmd, raise_exception=True)
                except Exception:
                    syslog.syslog(syslog.LOG_ERR, "Feature '{}.{}' failed to be stopped and disabled"
                                  .format(feature_name, feature_suffixes[-1]))
                    self.set_feature_state(feature, self.FEATURE_STATE_FAILED)
                    return

        self.set_feature_state(feature, self.FEATURE_STATE_DISABLED)

    def resync_feature_state(self, feature):
        """ Write feature.state back to the FEATURE entry in CONFIG_DB. """
        self._config_db.mod_entry('FEATURE', feature.name, {'state': feature.state})

    def set_feature_state(self, feature, state):
        """ Record the runtime state (enabled/disabled/failed) in the feature state table. """
        self._feature_state_table.set(feature.name, [('state', state)])


class Iptables(object):
    '''
    Maintains TCPMSS clamping rules in the iptables/ip6tables mangle table for
    the addresses found in interface table keys.
    '''

    def __init__(self):
        '''
        Default MSS to 1460 - (MTU 1500 - 40 (TCP/IP Overhead))
        For IPv6, it would be 1440 - (MTU 1500 - 60 octets)
        '''
        self.tcpmss = 1460
        self.tcp6mss = 1440

    def is_ip_prefix_in_key(self, key):
        '''
        Return True when the key carries an IP prefix. Such keys are tuples
        (interface, prefix); keys without an address are plain strings.
        '''
        return isinstance(key, tuple)

    def load(self, lpbk_table):
        '''Apply the rules for every entry of the given table.'''
        for entry_key in lpbk_table:
            self.iptables_handler(entry_key, lpbk_table[entry_key])

    def command(self, chain, ip, ver, op):
        '''
        Build the iptables (ver '4') or ip6tables command that performs `op`
        (e.g. check/append/delete) on the TCPMSS rule for `ip` in `chain`.
        PREROUTING matches on destination, any other chain on source.
        '''
        is_v4 = ver == '4'
        binary = 'iptables' if is_v4 else 'ip6tables'
        direction = '-d' if chain == 'PREROUTING' else '-s'
        mss = self.tcpmss if is_v4 else self.tcp6mss
        return '{} -t mangle --{} {} -p tcp --tcp-flags SYN SYN {} {} -j TCPMSS --set-mss {}'.format(
            binary, op, chain, direction, ip, mss)

    def iptables_handler(self, key, data, add=True):
        '''Add (or, with add=False, remove) the rules for the address in `key`.'''
        if not self.is_ip_prefix_in_key(key):
            return

        _, prefix = key
        address = prefix.split("/")[0]
        is_v6 = isinstance(ipaddress.ip_address(address), ipaddress.IPv6Address)
        self.mangle_handler(address, '6' if is_v6 else '4', add)

    def mangle_handler(self, ip, ver, add):
        '''Install or delete the PREROUTING and POSTROUTING rules for `ip`.'''
        op = 'check' if add else 'delete'

        pending = []
        for chain in ('PREROUTING', 'POSTROUTING'):
            cmd = self.command(chain, ip, ver, op)
            if not add:
                pending.append(cmd)
                continue
            # iptables appends duplicates blindly, so probe with --check first
            # and only append when the rule is absent.
            if subprocess.call(cmd, shell=True) == 0:
                syslog.syslog(syslog.LOG_INFO, "{} rule exists in {}".format(ip, chain))
            else:
                pending.append(cmd.replace("check", "append"))

        for cmd in pending:
            syslog.syslog(syslog.LOG_INFO, "Running cmd - {}".format(cmd))
            run_cmd(cmd)


class AaaCfg(object):
    """ Renders PAM/NSS configuration for AAA (local, TACACS+ and RADIUS) from cached CONFIG_DB data. """

    def __init__(self, config_db=None):
        """
        Args:
            config_db: Optional CONFIG_DB connector used by get_interface_ip() to look up
                interface addresses (RADIUS nas_ip / src_intf). It may also be assigned
                later through the ``config_db`` attribute; while unset, lookups yield "".
        """
        self.config_db = config_db

        self.authentication_default = {
            'login': 'local',
        }
        self.authorization_default = {
            'login': 'local',
        }
        self.accounting_default = {
            'login': 'disable',
        }
        self.tacplus_global_default = {
            'auth_type': TACPLUS_SERVER_AUTH_TYPE_DEFAULT,
            'timeout': TACPLUS_SERVER_TIMEOUT_DEFAULT,
            'passkey': TACPLUS_SERVER_PASSKEY_DEFAULT
        }
        self.tacplus_global = {}
        self.tacplus_servers = {}

        self.radius_global_default = {
            'priority': 0,
            'auth_port': RADIUS_SERVER_AUTH_PORT_DEFAULT,
            'auth_type': RADIUS_SERVER_AUTH_TYPE_DEFAULT,
            'retransmit': RADIUS_SERVER_RETRANSMIT_DEFAULT,
            'timeout': RADIUS_SERVER_TIMEOUT_DEFAULT,
            'passkey': RADIUS_SERVER_PASSKEY_DEFAULT
        }
        self.radius_global = {}
        self.radius_servers = {}

        self.authentication = {}
        self.authorization = {}
        self.accounting = {}
        self.debug = False
        self.trace = False

        self.hostname = ""

    # Load conf from ConfigDb
    def load(self, aaa_conf, tac_global_conf, tacplus_conf, rad_global_conf, radius_conf):
        """ Cache all AAA-related tables, then regenerate the configuration files once. """
        for row in aaa_conf:
            self.aaa_update(row, aaa_conf[row], modify_conf=False)
        for row in tac_global_conf:
            self.tacacs_global_update(row, tac_global_conf[row], modify_conf=False)
        for row in tacplus_conf:
            self.tacacs_server_update(row, tacplus_conf[row], modify_conf=False)

        for row in rad_global_conf:
            self.radius_global_update(row, rad_global_conf[row], modify_conf=False)
        for row in radius_conf:
            self.radius_server_update(row, radius_conf[row], modify_conf=False)

        self.modify_conf_file()

    def aaa_update(self, key, data, modify_conf=True):
        """ Cache an AAA table entry ('authentication', 'authorization' or 'accounting'). """
        if key == 'authentication':
            # Copy so that normalizing 'failthrough' does not mutate the caller's dict.
            self.authentication = dict(data)
            if 'failthrough' in data:
                self.authentication['failthrough'] = is_true(data['failthrough'])
            if 'debug' in data:
                self.debug = is_true(data['debug'])
        if key == 'authorization':
            self.authorization = data
        if key == 'accounting':
            self.accounting = data
        if modify_conf:
            self.modify_conf_file()

    def pick_src_intf_ipaddrs(self, keys, src_intf):
        """ Pick the first IPv4 and the first IPv6 address configured on src_intf.

        Args:
            keys: Interface table keys. Address entries are (ifname, "ip/prefixlen")
                tuples; any other key is ignored.
            src_intf (str): Interface name to match.

        Returns:
            tuple(str, str): (ipv4_addr, ipv6_addr), each "" when none was found.
        """
        new_ipv4_addr = ""
        new_ipv6_addr = ""

        for it in keys:
            # Check the key type before indexing it: plain string keys carry no address.
            if not isinstance(it, tuple) or src_intf != it[0]:
                continue
            if new_ipv4_addr != "" and new_ipv6_addr != "":
                break
            ip_str = it[1].split("/")[0]
            try:
                ip_addr = ipaddress.ip_address(ip_str)
            except ValueError:
                syslog.syslog(syslog.LOG_WARNING, "Ignoring invalid IP address '{}' on {}".format(ip_str, src_intf))
                continue
            # Pick the first IP address from the table that matches the source interface
            if isinstance(ip_addr, ipaddress.IPv6Address):
                if new_ipv6_addr != "":
                    continue
                new_ipv6_addr = ip_str
            else:
                if new_ipv4_addr != "":
                    continue
                new_ipv4_addr = ip_str

        return (new_ipv4_addr, new_ipv6_addr)

    def tacacs_global_update(self, key, data, modify_conf=True):
        """ Cache the TACPLUS 'global' entry. """
        if key == 'global':
            self.tacplus_global = data
            if modify_conf:
                self.modify_conf_file()

    def tacacs_server_update(self, key, data, modify_conf=True):
        """ Cache (or, for empty data, forget) a TACPLUS_SERVER entry. """
        if data == {}:
            if key in self.tacplus_servers:
                del self.tacplus_servers[key]
        else:
            self.tacplus_servers[key] = data

        if modify_conf:
            self.modify_conf_file()

    def notify_audisp_tacplus_reload_config(self):
        """ Send SIGHUP to audisp-tacplus (if running) so it reloads its TACACS+ config. """
        pid = get_pid("/sbin/audisp-tacplus")
        syslog.syslog(syslog.LOG_INFO, "Found audisp-tacplus PID: {}".format(pid))
        if pid == "":
            return

        # audisp-tacplus will reload TACACS+ config when receive SIGHUP
        try:
            os.kill(int(pid), signal.SIGHUP)
        except Exception as ex:
            syslog.syslog(syslog.LOG_WARNING, "Send SIGHUP to audisp-tacplus failed with exception: {}".format(ex))

    def handle_radius_source_intf_ip_chg(self, key):
        """ Regenerate config if an address change on key[0] affects a RADIUS src_intf. """
        modify_conf = False
        if 'src_intf' in self.radius_global:
            if key[0] == self.radius_global['src_intf']:
                modify_conf = True
        for addr in self.radius_servers:
            if ('src_intf' in self.radius_servers[addr]) and \
                    (key[0] == self.radius_servers[addr]['src_intf']):
                modify_conf = True
                break

        if not modify_conf:
            return

        syslog.syslog(syslog.LOG_INFO, 'RADIUS IP change - key:{}, current server info {}'.format(key, self.radius_servers))
        self.modify_conf_file()

    def handle_radius_nas_ip_chg(self, key):
        """ Regenerate config if a management IP change affects the default RADIUS nas_ip. """
        modify_conf = False
        # Mgmt IP configuration affects only the default nas_ip
        if 'nas_ip' not in self.radius_global:
            for addr in self.radius_servers:
                if 'nas_ip' not in self.radius_servers[addr]:
                    modify_conf = True
                    break

        if not modify_conf:
            return

        syslog.syslog(syslog.LOG_INFO, 'RADIUS (NAS) IP change - key:{}, current global info {}'.format(key, self.radius_global))
        self.modify_conf_file()

    def radius_global_update(self, key, data, modify_conf=True):
        """ Cache the RADIUS 'global' entry, normalizing 'statistics' to a bool. """
        if key == 'global':
            # Copy so that normalizing 'statistics' does not mutate the caller's dict.
            self.radius_global = dict(data)
            if 'statistics' in data:
                self.radius_global['statistics'] = is_true(data['statistics'])
            if modify_conf:
                self.modify_conf_file()

    def radius_server_update(self, key, data, modify_conf=True):
        """ Cache (or, for empty data, forget) a RADIUS_SERVER entry. """
        if data == {}:
            if key in self.radius_servers:
                del self.radius_servers[key]
        else:
            self.radius_servers[key] = data

        if modify_conf:
            self.modify_conf_file()

    def hostname_update(self, hostname, modify_conf=True):
        """ Cache the hostname (used as the default RADIUS nas_id). """
        if self.hostname == hostname:
            return

        self.hostname = hostname

        # Currently only used for RADIUS
        if len(self.radius_servers) == 0:
            return

        if modify_conf:
            self.modify_conf_file()

    def get_hostname(self):
        return self.hostname

    def get_interface_ip(self, source, addr=None):
        """ Return an IP address configured on interface `source`, or "".

        Args:
            source (str): Interface name (Ethernet*, Po*, Vlan*, Loopback*, eth0, or an
                Ethernet/Po sub-interface containing one '.').
            addr: Peer address, as text or an ipaddress object. An IPv6 peer selects the
                interface's IPv6 address; anything else (including None or a hostname)
                selects IPv4.
        """
        config_db = getattr(self, 'config_db', None)
        if config_db is None:
            return ""

        keys = None
        try:
            if source.startswith("Eth"):
                if is_vlan_sub_interface(source):
                    keys = config_db.get_keys('VLAN_SUB_INTERFACE')
                else:
                    keys = config_db.get_keys('INTERFACE')
            elif source.startswith("Po"):
                if is_vlan_sub_interface(source):
                    keys = config_db.get_keys('VLAN_SUB_INTERFACE')
                else:
                    keys = config_db.get_keys('PORTCHANNEL_INTERFACE')
            elif source.startswith("Vlan"):
                keys = config_db.get_keys('VLAN_INTERFACE')
            elif source.startswith("Loopback"):
                keys = config_db.get_keys('LOOPBACK_INTERFACE')
            elif source == "eth0":
                keys = config_db.get_keys('MGMT_INTERFACE')
        except Exception as e:
            syslog.syslog(syslog.LOG_WARNING, "Failed to read addresses of {}: {}".format(source, e))

        if keys is None:
            return ""

        ipv4_addr, ipv6_addr = self.pick_src_intf_ipaddrs(keys, source)

        # Server addresses arrive as text (e.g. RADIUS_SERVER keys); parse them so
        # that an IPv6 server actually selects the IPv6 source address.
        if isinstance(addr, str):
            try:
                addr = ipaddress.ip_address(addr)
            except ValueError:
                # Not an IP literal (e.g. a hostname). Resolving it would involve a
                # DNS query, which offline configuration might trip on, so use IPv4.
                addr = None

        if isinstance(addr, ipaddress.IPv6Address):
            return ipv6_addr
        return ipv4_addr

    def modify_single_file(self, filename, operations=None):
        """ Apply sed expressions to filename in place, keeping the previous copy as <filename>.old. """
        if operations:
            cmd = "sed -e {0} {1} > {1}.new; mv -f {1} {1}.old; mv -f {1}.new {1}".format(' -e '.join(operations), filename)
            os.system(cmd)

    def modify_conf_file(self):
        """ Regenerate every AAA-related file from the cached configuration.

        Renders the common-auth-sonic PAM file, the TACACS+/RADIUS NSS files and the
        per-server pam_radius_auth files; points /etc/pam.d/{sshd,login} and
        /etc/nsswitch.conf at the configured login methods; and starts or stops
        aaastatsd.
        """
        authentication = self.authentication_default.copy()
        authentication.update(self.authentication)
        authorization = self.authorization_default.copy()
        authorization.update(self.authorization)
        accounting = self.accounting_default.copy()
        accounting.update(self.accounting)
        tacplus_global = self.tacplus_global_default.copy()
        tacplus_global.update(self.tacplus_global)
        # TACACS+ global source IP, consumed by the PAM and the NSS TACACS+ templates below.
        src_ip = tacplus_global.get('src_ip')

        servers_conf = []
        if self.tacplus_servers:
            for addr in self.tacplus_servers:
                server = tacplus_global.copy()
                server['ip'] = addr
                server.update(self.tacplus_servers[addr])
                servers_conf.append(server)
            servers_conf = sorted(servers_conf, key=lambda t: int(t['priority']), reverse=True)

        radius_global = self.radius_global_default.copy()
        radius_global.update(self.radius_global)

        # RADIUS: Set the default nas_ip, and nas_id
        if 'nas_ip' not in radius_global:
            nas_ip = self.get_interface_ip("eth0")
            if len(nas_ip) > 0:
                radius_global['nas_ip'] = nas_ip
        if 'nas_id' not in radius_global:
            nas_id = self.get_hostname()
            if len(nas_id) > 0:
                radius_global['nas_id'] = nas_id

        radsrvs_conf = []
        if self.radius_servers:
            for addr in self.radius_servers:
                server = radius_global.copy()
                server['ip'] = addr
                server.update(self.radius_servers[addr])

                if 'src_intf' in server:
                    # RADIUS: Log a message if src_ip is already defined.
                    if 'src_ip' in server:
                        syslog.syslog(syslog.LOG_INFO,
                                      "RADIUS_SERVER|{}: src_intf found. Ignoring src_ip".format(addr))
                    # RADIUS: If server.src_intf, then get the corresponding
                    # src_ip based on the server.ip, and set it.
                    # A separate name is used so the TACACS+ `src_ip` above is not clobbered.
                    intf_src_ip = self.get_interface_ip(server['src_intf'], addr)
                    if len(intf_src_ip) > 0:
                        server['src_ip'] = intf_src_ip
                    elif 'src_ip' in server:
                        syslog.syslog(syslog.LOG_INFO,
                                      "RADIUS_SERVER|{}: src_intf has no usable IP addr.".format(addr))
                        del server['src_ip']

                radsrvs_conf.append(server)
            radsrvs_conf = sorted(radsrvs_conf, key=lambda t: int(t['priority']), reverse=True)

        template_file = os.path.abspath(PAM_AUTH_CONF_TEMPLATE)
        env = jinja2.Environment(loader=jinja2.FileSystemLoader('/'), trim_blocks=True)
        env.filters['sub'] = sub
        template = env.get_template(template_file)
        if 'radius' in authentication['login']:
            pam_conf = template.render(debug=self.debug, trace=self.trace, auth=authentication, servers=radsrvs_conf)
        else:
            pam_conf = template.render(auth=authentication, src_ip=src_ip, servers=servers_conf)

        # Use rename(), which is atomic (on the same fs) to avoid empty file
        with open(PAM_AUTH_CONF + ".tmp", 'w') as f:
            f.write(pam_conf)
        os.chmod(PAM_AUTH_CONF + ".tmp", 0o644)
        os.rename(PAM_AUTH_CONF + ".tmp", PAM_AUTH_CONF)

        # Modify common-auth include file in /etc/pam.d/login, sshd.
        # /etc/pam.d/sudo is not handled, because it would change the existing
        # behavior. It can be modified once a config knob is added for sudo.
        if os.path.isfile(PAM_AUTH_CONF):
            self.modify_single_file(ETC_PAMD_SSHD,  [ "'/^@include/s/common-auth$/common-auth-sonic/'" ])
            self.modify_single_file(ETC_PAMD_LOGIN, [ "'/^@include/s/common-auth$/common-auth-sonic/'" ])
        else:
            self.modify_single_file(ETC_PAMD_SSHD,  [ "'/^@include/s/common-auth-sonic$/common-auth/'" ])
            self.modify_single_file(ETC_PAMD_LOGIN, [ "'/^@include/s/common-auth-sonic$/common-auth/'" ])

        # Add tacplus/radius in nsswitch.conf if TACACS+/RADIUS enable
        if 'tacacs+' in authentication['login']:
            if os.path.isfile(NSS_CONF):
                self.modify_single_file(NSS_CONF, [ "'/^passwd/s/ radius//'" ])
                self.modify_single_file(NSS_CONF, [ "'/tacplus/b'", "'/^passwd/s/compat/tacplus &/'", "'/^passwd/s/files/tacplus &/'" ])
        elif 'radius' in authentication['login']:
            if os.path.isfile(NSS_CONF):
                self.modify_single_file(NSS_CONF, [ "'/^passwd/s/tacplus //'" ])
                self.modify_single_file(NSS_CONF, [ "'/radius/b'", "'/^passwd/s/compat/& radius/'", "'/^passwd/s/files/& radius/'" ])
        else:
            if os.path.isfile(NSS_CONF):
                self.modify_single_file(NSS_CONF, [ "'/^passwd/s/tacplus //g'" ])
                self.modify_single_file(NSS_CONF, [ "'/^passwd/s/ radius//'" ])

        # Add tacplus authorization configuration in nsswitch.conf
        tacacs_authorization_conf = None
        local_authorization_conf = None
        if 'tacacs+' in authorization['login']:
            tacacs_authorization_conf = "on"
        if 'local' in authorization['login']:
            local_authorization_conf = "on"

        # Add tacplus accounting configuration in nsswitch.conf
        tacacs_accounting_conf = None
        local_accounting_conf = None
        if 'tacacs+' in accounting['login']:
            tacacs_accounting_conf = "on"
        if 'local' in accounting['login']:
            local_accounting_conf = "on"

        # Set tacacs+ server in nss-tacplus conf
        template_file = os.path.abspath(NSS_TACPLUS_CONF_TEMPLATE)
        template = env.get_template(template_file)
        nss_tacplus_conf = template.render(
                                        debug=self.debug,
                                        src_ip=src_ip,
                                        servers=servers_conf,
                                        local_accounting=local_accounting_conf,
                                        tacacs_accounting=tacacs_accounting_conf,
                                        local_authorization=local_authorization_conf,
                                        tacacs_authorization=tacacs_authorization_conf)
        with open(NSS_TACPLUS_CONF, 'w') as f:
            f.write(nss_tacplus_conf)

        # Notify auditd plugin to reload tacacs config.
        self.notify_audisp_tacplus_reload_config()

        # Set debug in nss-radius conf
        template_file = os.path.abspath(NSS_RADIUS_CONF_TEMPLATE)
        template = env.get_template(template_file)
        nss_radius_conf = template.render(debug=self.debug, trace=self.trace, servers=radsrvs_conf)
        with open(NSS_RADIUS_CONF, 'w') as f:
            f.write(nss_radius_conf)

        # Create the per server pam_radius_auth.conf
        if radsrvs_conf:
            template = env.get_template(os.path.abspath(PAM_RADIUS_AUTH_CONF_TEMPLATE))
            for srv in radsrvs_conf:
                # Configuration File
                pam_radius_auth_file = RADIUS_PAM_AUTH_CONF_DIR + srv['ip'] + "_" + str(srv['auth_port']) + ".conf"
                pam_radius_auth_conf = template.render(server=srv)

                # Create the file and restrict it to 0600 before the passkey is written.
                open(pam_radius_auth_file, 'a').close()
                os.chmod(pam_radius_auth_file, 0o600)
                with open(pam_radius_auth_file, 'w+') as f:
                    f.write(pam_radius_auth_conf)

        # Start the statistics service. Only RADIUS implemented
        if ('radius' in authentication['login']) and ('statistics' in radius_global) and \
                radius_global['statistics']:
            cmd = 'service aaastatsd start'
        else:
            cmd = 'service aaastatsd stop'
        syslog.syslog(syslog.LOG_INFO, "cmd - {}".format(cmd))
        try:
            subprocess.check_call(cmd, shell=True)
        except subprocess.CalledProcessError as err:
            syslog.syslog(syslog.LOG_ERR,
                          "{} - failed: return code - {}, output:\n{}"
                          .format(err.cmd, err.returncode, err.output))


class KdumpCfg(object):
    """
    Kdump configuration handler.

    Seeds missing fields of the KDUMP|config row in CONFIG_DB with defaults
    and applies updates to that row through the sonic-kdump-config utility.
    """
    def __init__(self, CfgDb):
        self.config_db = CfgDb
        # Used both to seed CONFIG_DB and as the fallback for any field that
        # is absent from an update.
        self.kdump_defaults = {"enabled": "false",
                               "memory": "0M-2G:256M,2G-4G:320M,4G-8G:384M,8G-:448M",
                               "num_dumps": "3"}

    def load(self, kdump_table):
        """
        Set the KDUMP table in CFG DB to kdump_defaults if not set by the user.

        Args:
            kdump_table: initial KDUMP table contents (row name -> field dict).
                         None is treated as an empty table.
        """
        syslog.syslog(syslog.LOG_INFO, "KdumpCfg init ...")
        kdump_conf = (kdump_table or {}).get("config", {})
        for row, value in self.kdump_defaults.items():
            # Missing and empty fields are both replaced by the default.
            if not kdump_conf.get(row):
                self.config_db.mod_entry("KDUMP", "config", {row: value})

    def _setting(self, data, field):
        """Return data[field] unless it is missing or None, else the default."""
        value = data.get(field)
        return value if value is not None else self.kdump_defaults[field]

    def kdump_update(self, key, data):
        """
        Apply an update of the KDUMP table.

        Only the "config" row is handled. Every field falls back to its value
        in kdump_defaults when it is absent.

        Args:
            key: KDUMP row name.
            data: field dict of the row, or None when the row was deleted
                  (the subscription callback signals a deletion with None).
                  A deleted row is treated like an empty one, so the defaults
                  are re-applied instead of raising AttributeError.
        """
        syslog.syslog(syslog.LOG_INFO, "Kdump global configuration update")
        if key != "config":
            return
        if data is None:
            data = {}

        # Admin mode
        if self._setting(data, "enabled").lower() == "true":
            run_cmd("sonic-kdump-config --enable")
        else:
            run_cmd("sonic-kdump-config --disable")

        # Memory configuration
        run_cmd("sonic-kdump-config --memory " + self._setting(data, "memory"))

        # Num dumps
        run_cmd("sonic-kdump-config --num_dumps " + self._setting(data, "num_dumps"))

class NtpCfg(object):
    """
    NtpCfg Config Daemon
    1) ntp-config.service handles the configuration updates and then starts ntp.service
    2) Both of them start after all the feature services start
    3) Purpose of this daemon is to propagate runtime config changes in
       NTP, NTP_SERVER and LOOPBACK_INTERFACE
    """
    def __init__(self):
        # Last seen NTP global row (e.g. 'src_intf', 'vrf').
        self.ntp_global = {}
        # Keys of the NTP_SERVER rows currently configured.
        self.ntp_servers = set()

    def load(self, ntp_global_conf, ntp_server_conf):
        """
        Load the initial NTP / NTP_SERVER tables and force an ntp-config restart.

        Args:
            ntp_global_conf: NTP table contents (row name -> field dict).
            ntp_server_conf: NTP_SERVER table contents (server key -> field dict).
        """
        syslog.syslog(syslog.LOG_INFO, "NtpCfg load ...")

        for row in (ntp_global_conf or {}):
            self.ntp_global_update(row, ntp_global_conf[row], is_load=True)

        # Seed the server cache from the initial data. Without this the cache
        # stays empty after a daemon restart, so a later DEL of a pre-existing
        # server would not restart ntp-config and source-interface / global
        # changes would be ignored until some server was re-added.
        self.ntp_servers = set(ntp_server_conf or {})

        # Force reload on init
        self.ntp_server_update(0, None, is_load=True)

    def handle_ntp_source_intf_chg(self, intf_name):
        """Restart ntp-config when a configured NTP source interface changes."""
        # if no ntp server configured, do nothing
        if not self.ntp_servers:
            return

        # check only the intf configured as source interface
        if intf_name not in self.ntp_global.get('src_intf', '').split(';'):
            return

        # just restart ntp config
        cmd = 'systemctl restart ntp-config'
        run_cmd(cmd)

    def ntp_global_update(self, key, data, is_load=False):
        """
        Cache an NTP global row and restart the relevant service on change.

        A change of the source-interface set restarts ntp-config; otherwise a
        change of vrf restarts ntp. Nothing is restarted on initial load or
        when no server is configured.

        Args:
            key: NTP row name.
            data: field dict, or None when the row was deleted (treated as {}).
            is_load: True during initial load; only updates the cache.
        """
        syslog.syslog(syslog.LOG_INFO, 'NTP GLOBAL Update')
        if data is None:
            data = {}

        orig_src = self.ntp_global.get('src_intf', '')
        orig_src_set = set(orig_src.split(";"))
        orig_vrf = self.ntp_global.get('vrf', '')

        new_src = data.get('src_intf', '')
        new_src_set = set(new_src.split(";"))
        new_vrf = data.get('vrf', '')

        # Update the Local Cache
        self.ntp_global = data

        # If initial load don't restart daemon
        if is_load:
            return

        # check if ntp server configured, if not, do nothing
        if not self.ntp_servers:
            syslog.syslog(syslog.LOG_INFO, "No ntp server when global config change, do nothing")
            return

        if orig_src_set != new_src_set:
            syslog.syslog(syslog.LOG_INFO, "ntp global update for source intf old {} new {}, restarting ntp-config"
                          .format(orig_src_set, new_src_set))
            cmd = 'systemctl restart ntp-config'
            run_cmd(cmd)
        elif new_vrf != orig_vrf:
            syslog.syslog(syslog.LOG_INFO, "ntp global update for vrf old {} new {}, restarting ntp service"
                          .format(orig_vrf, new_vrf))
            cmd = 'service ntp restart'
            run_cmd(cmd)

    def ntp_server_update(self, key, op, is_load=False):
        """
        Track NTP_SERVER additions/removals and restart ntp-config on change.

        Args:
            key: NTP_SERVER row key.
            op: "SET" or "DEL".
            is_load: True forces a restart regardless of key/op.
        """
        syslog.syslog(syslog.LOG_INFO, 'ntp server update key {}'.format(key))

        restart_config = False
        if not is_load:
            if op == "SET" and key not in self.ntp_servers:
                restart_config = True
                self.ntp_servers.add(key)
            elif op == "DEL" and key in self.ntp_servers:
                restart_config = True
                self.ntp_servers.remove(key)
        else:
            restart_config = True

        if restart_config:
            cmd = 'systemctl restart ntp-config'
            syslog.syslog(syslog.LOG_INFO, 'ntp server update, restarting ntp-config, ntp servers configured {}'.format(self.ntp_servers))
            run_cmd(cmd)

class PamLimitsCfg(object):
    """
    PamLimit Config Daemon
    1) The pam_limits PAM module sets limits on the system resources that can be obtained in a user-session.
    2) Purpose of this daemon is to render pam_limits config file.

    The rendered files depend on the hwsku and type fields of
    DEVICE_METADATA|localhost.
    """
    def __init__(self, config_db):
        self.config_db = config_db
        self.hwsku = ""
        self.type = ""

    def update_config_file(self):
        """Read DEVICE_METADATA|localhost and re-render the limit files.

        Does nothing when the localhost row is missing.
        """
        metadata = self.config_db.get_table('DEVICE_METADATA')
        if "localhost" not in metadata:
            return

        self.read_localhost_config(metadata["localhost"])
        self.render_conf_file()

    def read_localhost_config(self, localhost):
        """Cache hwsku/type from the localhost row; absent fields become ""."""
        self.hwsku = localhost["hwsku"] if "hwsku" in localhost else ""
        self.type = localhost["type"] if "type" in localhost else ""

    def render_conf_file(self):
        """Render the pam_limits PAM config and limits.conf from their templates.

        Rendering/writing errors are logged, not raised. Files are processed
        in order, so a failure on the first one skips the second.
        """
        env = jinja2.Environment(loader=jinja2.FileSystemLoader('/'), trim_blocks=True)
        env.filters['sub'] = sub

        try:
            jobs = ((PAM_LIMITS_CONF_TEMPLATE, PAM_LIMITS_CONF),
                    (LIMITS_CONF_TEMPLATE, LIMITS_CONF))
            for template_path, output_path in jobs:
                rendered = env.get_template(os.path.abspath(template_path)).render(
                    hwsku=self.hwsku,
                    type=self.type)
                with open(output_path, 'w') as out:
                    out.write(rendered)
        except Exception as e:
            syslog.syslog(syslog.LOG_ERR,
                          "modify pam_limits config file failed with exception: {}"
                          .format(e))

class HostConfigDaemon:
    """
    Top-level hostcfgd daemon.

    Connects to CONFIG_DB/STATE_DB, builds the per-area handlers (kdump,
    iptables, features, NTP, AAA, pam_limits), loads the initial
    configuration and dispatches CONFIG_DB table updates to them.
    """
    def __init__(self):
        # Just a sanity check to verify if the CONFIG_DB has been initialized
        # before moving forward
        self.config_db = ConfigDBConnector()
        self.config_db.connect(wait_for_init=True, retry_on=True)
        syslog.syslog(syslog.LOG_INFO, 'ConfigDB connect success')

        # Load DEVICE metadata configurations
        self.device_config = {}
        self.device_config['DEVICE_METADATA'] = self.config_db.get_table('DEVICE_METADATA')

        # Load feature state table
        self.state_db_conn = DBConnector(STATE_DB, 0)
        feature_state_table = Table(self.state_db_conn, 'FEATURE')

        # Initialize KDump Config and set the config to default if nothing is provided
        self.kdumpCfg = KdumpCfg(self.config_db)

        # Initialize IpTables
        self.iptables = Iptables()

        # Initialize Feature Handler
        self.feature_handler = FeatureHandler(self.config_db, feature_state_table, self.device_config)

        # Initialize Ntp Config Handler
        self.ntpcfg = NtpCfg()

        self.is_multi_npu = device_info.is_multi_npu()

        # Initialize AAACfg
        self.hostname_cache = ""
        self.aaacfg = AaaCfg()

        # Initialize PamLimitsCfg and render its files once at startup
        self.pamLimitsCfg = PamLimitsCfg(self.config_db)
        self.pamLimitsCfg.update_config_file()

    def load(self, init_data):
        """
        Initial-data handler passed to ConfigDBConnector.listen().

        Args:
            init_data: mapping of table name -> table contents; must contain
                       every table subscribed in register_callbacks that is
                       read below.
        """
        features = init_data['FEATURE']
        aaa = init_data['AAA']
        tacacs_global = init_data['TACPLUS']
        tacacs_server = init_data['TACPLUS_SERVER']
        radius_global = init_data['RADIUS']
        radius_server = init_data['RADIUS_SERVER']
        lpbk_table = init_data['LOOPBACK_INTERFACE']
        ntp_server = init_data['NTP_SERVER']
        ntp_global = init_data['NTP']
        kdump = init_data['KDUMP']

        self.feature_handler.sync_state_field(features)
        self.aaacfg.load(aaa, tacacs_global, tacacs_server, radius_global, radius_server)
        self.iptables.load(lpbk_table)
        self.ntpcfg.load(ntp_global, ntp_server)
        self.kdumpCfg.load(kdump)

        dev_meta = self.config_db.get_table('DEVICE_METADATA')
        if 'localhost' in dev_meta:
            if 'hostname' in dev_meta['localhost']:
                self.hostname_cache = dev_meta['localhost']['hostname']

        # Update AAA with the hostname
        self.aaacfg.hostname_update(self.hostname_cache)

    def __get_intf_name(self, key):
        """Return the interface name: first element of a tuple key, else the key itself."""
        if isinstance(key, tuple) and key:
            intf = key[0]
        else:
            intf = key
        return intf

    @staticmethod
    def _redact_passkey(data):
        """
        Return a deep copy of *data* that is safe to write to syslog.

        A 'passkey' field, if present, is replaced with its obfuscated form.
        *data* is None for deleted rows (see make_callback in
        register_callbacks); None is returned unchanged instead of raising
        TypeError on the membership test.
        """
        log_data = copy.deepcopy(data)
        if log_data and 'passkey' in log_data:
            log_data['passkey'] = obfuscate(log_data['passkey'])
        return log_data

    def aaa_handler(self, key, op, data):
        self.aaacfg.aaa_update(key, data)
        syslog.syslog(syslog.LOG_INFO, 'AAA Update: key: {}, op: {}, data: {}'.format(key, op, data))

    def tacacs_server_handler(self, key, op, data):
        self.aaacfg.tacacs_server_update(key, data)
        log_data = self._redact_passkey(data)
        syslog.syslog(syslog.LOG_INFO, 'TACPLUS_SERVER update: key: {}, op: {}, data: {}'.format(key, op, log_data))

    def tacacs_global_handler(self, key, op, data):
        self.aaacfg.tacacs_global_update(key, data)
        log_data = self._redact_passkey(data)
        syslog.syslog(syslog.LOG_INFO, 'TACPLUS Global update: key: {}, op: {}, data: {}'.format(key, op, log_data))

    def radius_server_handler(self, key, op, data):
        self.aaacfg.radius_server_update(key, data)
        log_data = self._redact_passkey(data)
        syslog.syslog(syslog.LOG_INFO, 'RADIUS_SERVER update: key: {}, op: {}, data: {}'.format(key, op, log_data))

    def radius_global_handler(self, key, op, data):
        self.aaacfg.radius_global_update(key, data)
        log_data = self._redact_passkey(data)
        syslog.syslog(syslog.LOG_INFO, 'RADIUS Global update: key: {}, op: {}, data: {}'.format(key, op, log_data))

    def mgmt_intf_handler(self, key, op, data):
        key = ConfigDBConnector.deserialize_key(key)
        mgmt_intf_name = self.__get_intf_name(key)
        self.aaacfg.handle_radius_source_intf_ip_chg(mgmt_intf_name)
        self.aaacfg.handle_radius_nas_ip_chg(mgmt_intf_name)

    def lpbk_handler(self, key, op, data):
        key = ConfigDBConnector.deserialize_key(key)
        if op == "DEL":
            add = False
        else:
            add = True

        self.iptables.iptables_handler(key, data, add)
        lpbk_name = self.__get_intf_name(key)
        self.ntpcfg.handle_ntp_source_intf_chg(lpbk_name)
        self.aaacfg.handle_radius_source_intf_ip_chg(key)

    def vlan_intf_handler(self, key, op, data):
        key = ConfigDBConnector.deserialize_key(key)
        self.aaacfg.handle_radius_source_intf_ip_chg(key)

    def vlan_sub_intf_handler(self, key, op, data):
        key = ConfigDBConnector.deserialize_key(key)
        self.aaacfg.handle_radius_source_intf_ip_chg(key)

    def portchannel_intf_handler(self, key, op, data):
        key = ConfigDBConnector.deserialize_key(key)
        self.aaacfg.handle_radius_source_intf_ip_chg(key)

    def phy_intf_handler(self, key, op, data):
        key = ConfigDBConnector.deserialize_key(key)
        self.aaacfg.handle_radius_source_intf_ip_chg(key)

    def ntp_server_handler(self, key, op, data):
        self.ntpcfg.ntp_server_update(key, op)

    def ntp_global_handler(self, key, op, data):
        self.ntpcfg.ntp_global_update(key, data)

    def kdump_handler(self, key, op, data):
        syslog.syslog(syslog.LOG_INFO, 'Kdump handler...')
        self.kdumpCfg.kdump_update(key, data)

    def wait_till_system_init_done(self):
        """Block until `systemctl is-system-running --wait` returns."""
        # No need to print the output in the log file so using the "--quiet"
        # flag
        systemctl_cmd = "sudo systemctl is-system-running --wait --quiet"
        subprocess.call(systemctl_cmd, shell=True)

    def register_callbacks(self):
        """Subscribe every handled CONFIG_DB table, then wait for systemd init."""

        def make_callback(func):
            # Adapt the (table, key, data) subscription callback to the
            # handlers' (key, op, data) signature; data None means deletion.
            def callback(table, key, data):
                if data is None:
                    op = "DEL"
                else:
                    op = "SET"
                return func(key, op, data)
            return callback

        self.config_db.subscribe('KDUMP', make_callback(self.kdump_handler))
        # Handle FEATURE updates before other tables
        self.config_db.subscribe('FEATURE', make_callback(self.feature_handler.handle))
        # Handle AAA, TACACS and RADIUS related tables
        self.config_db.subscribe('AAA', make_callback(self.aaa_handler))
        self.config_db.subscribe('TACPLUS', make_callback(self.tacacs_global_handler))
        self.config_db.subscribe('TACPLUS_SERVER', make_callback(self.tacacs_server_handler))
        self.config_db.subscribe('RADIUS', make_callback(self.radius_global_handler))
        self.config_db.subscribe('RADIUS_SERVER', make_callback(self.radius_server_handler))
        # Handle IPTables configuration
        self.config_db.subscribe('LOOPBACK_INTERFACE', make_callback(self.lpbk_handler))
        # Handle NTP & NTP_SERVER updates
        self.config_db.subscribe('NTP', make_callback(self.ntp_global_handler))
        self.config_db.subscribe('NTP_SERVER', make_callback(self.ntp_server_handler))
        # Handle updates to src intf changes in radius
        self.config_db.subscribe('MGMT_INTERFACE', make_callback(self.mgmt_intf_handler))
        self.config_db.subscribe('VLAN_INTERFACE', make_callback(self.vlan_intf_handler))
        self.config_db.subscribe('VLAN_SUB_INTERFACE', make_callback(self.vlan_sub_intf_handler))
        self.config_db.subscribe('PORTCHANNEL_INTERFACE', make_callback(self.portchannel_intf_handler))
        self.config_db.subscribe('INTERFACE', make_callback(self.phy_intf_handler))

        syslog.syslog(syslog.LOG_INFO,
                      "Waiting for systemctl to finish initialization")
        self.wait_till_system_init_done()
        syslog.syslog(syslog.LOG_INFO,
                      "systemctl has finished initialization -- proceeding ...")

    def start(self):
        """Enter the CONFIG_DB listen loop, using load() for the initial data."""
        self.config_db.listen(init_data_handler=self.load)


def main():
    """Install signal handlers, then build and run the host config daemon."""
    # Same handler for every signal, registered in the order
    # SIGTERM, SIGINT, SIGHUP.
    for signum in (signal.SIGTERM, signal.SIGINT, signal.SIGHUP):
        signal.signal(signum, signal_handler)

    hostcfg = HostConfigDaemon()
    hostcfg.register_callbacks()
    hostcfg.start()

if __name__ == "__main__":
    main()

